import pandas as pd
import numpy as np

from copy import deepcopy, copy
from sklearn.linear_model import LogisticRegression
import cvxpy as cvx


class logistic_bandits(object):
    """Simulate a budget-constrained contextual bandit driven by a logistic UCB model.

    At every step the simulator takes one context row, chooses a discount
    action, draws a conversion from ``dict_params["model_conversion"]`` and
    records the reward and two cost streams. After ``n_random_action``
    exploratory steps a logistic regression is refit on the history, optimistic
    (UCB) estimates are summed per ``index_context_approx`` and a linear
    program solved with cvxpy yields a randomized policy per aggregated context.

    Keys read from ``dict_params``:
        T, seed, lmd, UCB_multiply, norm_costs1, budget, verbose,
        list_actions, n_random_action, model_conversion,
        var_rate, var_context, var_base_reward_costs, var_model_onehot.

    Columns read from the environment frame: ``index_context_approx``,
    ``amount_norm``, ``discount_base_norm`` plus the ``var_*`` columns above.
    """

    def __init__(self, dict_params, dt_env, prefix_sep = "_zl_"):
        self.dict_params = dict_params
        self.current_iter = 0
        # Per-step histories; one entry is appended to each at every iteration.
        self.actions = []
        self.conversions = []
        self.pred_conv = []
        self.rewards = []
        self.costs2 = []
        self.costs1 = []
        self.prefix_sep = prefix_sep
        # No conversion model exists until the history contains both outcomes.
        self.current_logistic = None

        # Seed before any sampling so that the resampling performed in
        # _init_sequence_simu is reproducible as well.
        np.random.seed(self.dict_params["seed"])

        self._init_sequence_simu(dt_env)
        self._init_constant()

    def _init_constant(self):
        """Cache frequently used hyper-parameters from ``dict_params``."""
        self.lmd = self.dict_params["lmd"]
        self.UCB_multiply = self.dict_params["UCB_multiply"]
        self.norm_costs1 = self.dict_params["norm_costs1"]

    def _init_sequence_simu(self, dt):
        """Build the context sequence ``self.dt`` and its one-hot view ``self.dt_dummy``.

        If ``dt`` does not have exactly ``T`` rows it is resampled with
        replacement to ``T`` rows. The one-hot encoding is computed from the
        resulting ``self.dt`` so that both frames stay row-aligned.
        """
        params = self.dict_params
        n_steps = params["T"]
        if dt.shape[0] != n_steps:
            print(f"Sampling data to reach {self.dict_params['T']}")
            self.dt = dt.sample(n_steps, replace=True).reset_index(drop=True)
        else:
            self.dt = dt

        raw_cols = [params["var_rate"]] + params["var_context"] + params["var_base_reward_costs"]
        # dtype=float keeps the dummies numeric: pandas >= 2 defaults to bool,
        # which turns DataFrame.values into an object array that np.linalg
        # cannot invert.
        dt_dummy = pd.get_dummies(self.dt[raw_cols], prefix_sep=self.prefix_sep, dtype=float)
        # Categories absent from this data still need their model column.
        for var_ in params["var_model_onehot"]:
            if var_ not in dt_dummy.columns:
                dt_dummy[var_] = 0.0

        self.dt_dummy = dt_dummy[params["var_model_onehot"] + params["var_base_reward_costs"]]

    def _cal_budget(self):
        """Set the budget allowed so far: ``(budget - 1)`` scaled by the fraction of steps done."""
        self.current_budget = (self.dict_params["budget"] - 1) * ((self.current_iter + 1) / self.dict_params["T"])

    def _get_context(self):
        """Extract the current row and the full backlog (rows 0..current_iter) as fresh copies."""
        start, stop = self.current_iter, self.current_iter + 1
        self.current_context = self.dt.iloc[start:stop].reset_index(drop=True).copy()
        self.current_context_dummy = self.dt_dummy.iloc[start:stop].reset_index(drop=True).copy()

        self.backlog_context = self.dt.iloc[:stop].reset_index(drop=True).copy()
        self.backlog_context_dummy = self.dt_dummy.iloc[:stop].reset_index(drop=True).copy()

    def _cal_constant_UCB(self):
        """Set the UCB width multiplier ``(log(t + 1) + 1) * UCB_multiply`` with ``t = current_iter``."""
        self.current_constant_UCB = (np.log(self.current_iter + 1) + 1) * self.UCB_multiply

    def _update_vt(self):
        """Accumulate ``V_t = sum(x x^T) + lmd * I`` over the (discounted) contexts seen so far."""
        x = self.current_context_dummy.values
        outer = np.outer(x, x)
        if self.current_iter == 0:
            self.vt_raw = outer
        else:
            self.vt_raw = self.vt_raw + outer

        self.vt = self.vt_raw + np.identity(x.shape[1]) * self.lmd

    def _cal_bonus_UCB(self, action, context, matrix_context = None):
        """Return the UCB exploration bonus for every row of ``context``.

        ``bonus_i = current_constant_UCB * sqrt(x_i^T V_t^{-1} x_i)``

        :param action: action label; only used in the verbose log line.
        :param context: 2-D array with one feature row per context.
        :param matrix_context: ignored; kept so existing callers that still
            pass an ``np.matrix`` copy of ``context`` keep working.
        :return: 1-D array of bonuses, one per row of ``context``.
        """
        x = np.asarray(context, dtype=float)
        # Row-wise quadratic form x_i^T V^{-1} x_i without np.matrix.
        quad_form = np.sum((x @ np.linalg.inv(self.vt)) * x, axis=1)
        current_bonus_UCB = self.current_constant_UCB * np.sqrt(quad_form)

        if self.dict_params["verbose"] is True and self.current_iter % 50 == 0:
            print(f"For action {action}: Average bonus {np.round(np.mean(current_bonus_UCB),4)}")

        return current_bonus_UCB

    def _take_action(self, policy = False, anull = False):
        """Choose, record and return the action for the current step.

        * ``anull`` -> action ``-1``.
        * ``policy`` -> sample from the row of ``self.current_policy`` whose
          ``index_context_approx`` matches the current context.
        * otherwise -> uniform choice over ``list_actions[1:]``
          (assumes ``list_actions[0]`` is the null action ``-1`` — confirm).
        """
        if anull is True:
            action = -1
        elif policy is True:
            policy_ = self.current_policy
            index_context_ = self.current_context["index_context_approx"].values[0]
            list_actions_ = [x for x in policy_.columns if x != "index_context_approx"]
            row_mask = policy_["index_context_approx"] == index_context_

            # Clip negative weights to 0 and renormalise so p is a valid distribution.
            prob_actions_ = np.maximum(policy_.loc[row_mask, list_actions_].values[0], 0)
            prob_actions_ = prob_actions_ / prob_actions_.sum()

            action = np.random.choice(a=list_actions_, p=prob_actions_)
        else:
            action = np.random.choice(self.dict_params["list_actions"][1:])

        self.actions.append(action)
        self.current_action = action
        return action

    def _update_action_context(self):
        """Apply the chosen discount (percent) to the rate feature of the current context; no-op for ``-1``."""
        if self.current_action != -1:
            rate_col = self.dict_params["var_rate"]
            self.current_context_dummy[rate_col] = self.current_context_dummy[rate_col] * (1 - self.current_action / 100)

    def _get_reward_costs(self):
        """Simulate the outcome of ``self.current_action`` and append it to the histories.

        For the null action ``-1`` every history receives 0. Otherwise the
        conversion probability is predicted by ``dict_params["model_conversion"]``
        on the (discounted) one-hot context, a Bernoulli conversion is drawn,
        and reward and both costs are derived from it.
        """
        if self.current_action == -1:
            for history in (self.conversions, self.pred_conv, self.rewards, self.costs2, self.costs1):
                history.append(0)
            return

        params = self.dict_params
        discount_ = self.current_action / 100
        dt_ = self.current_context.copy()
        dt_["pred_conv"] = params["model_conversion"].predict_proba(
            self.current_context_dummy[params["var_model_onehot"]])[:, 1]
        # Bernoulli draw: the row converts with probability pred_conv.
        dt_["conversion"] = dt_.apply(lambda x: np.random.rand() < x["pred_conv"], axis=1) * 1
        dt_["reward"] = dt_["conversion"] * dt_["amount_norm"]
        dt_["cost2"] = dt_["conversion"] * dt_["discount_base_norm"] * discount_
        dt_["cost1"] = (dt_["conversion"] * discount_) / params["norm_costs1"]

        if params["verbose"] is True and self.current_iter % 50 == 0:
            print(f"current conversion: {np.round(dt_['conversion'][0], 4)}")
            print(f"current reward: {np.round(dt_['reward'][0], 4)}")
            print(f"current cost2: {np.round(dt_['cost2'][0], 4)}")
            print(f"current cost1: {np.round(dt_['cost1'][0], 4)}")

        self.conversions.append(dt_["conversion"][0])
        self.pred_conv.append(dt_["pred_conv"][0])
        self.rewards.append(dt_["reward"][0])
        self.costs2.append(dt_["cost2"][0])
        self.costs1.append(dt_["cost1"][0])

    def _check_break_constraints(self):
        """Return ``True`` (a Python bool) once either cumulative cost exceeds ``budget - 1``."""
        limit = self.dict_params["budget"] - 1
        return bool(np.sum(self.costs2) > limit or np.sum(self.costs1) > limit)

    def _fit_logistic(self):
        """Refit the conversion model on the whole backlog.

        Features are the backlog contexts with each step's applied discount
        folded into the rate column; targets are the simulated conversions.
        The fit is regularised with ``C = 1 / lmd`` and has no intercept.
        If the history holds a single outcome class the fit is skipped and the
        previous model (possibly ``None``) is kept, since the classifier needs
        both classes to train.
        """
        if np.unique(self.conversions).size < 2:
            return

        params = self.dict_params
        discount_ = np.array(self.actions) / 100
        dt_dummy_ = self.backlog_context_dummy.copy()
        dt_dummy_[params["var_rate"]] = dt_dummy_[params["var_rate"]] * (1 - discount_)
        dt_dummy_ = dt_dummy_[params["var_model_onehot"] + params["var_base_reward_costs"]]
        self.current_logistic = LogisticRegression(C=1 / self.lmd, solver="lbfgs", fit_intercept=False)
        self.current_logistic.fit(X=dt_dummy_.values, y=self.conversions)

    def _calculate_gain_costs_matrix(self):
        """Build per-(aggregated context, action) conversion, UCB, target and cost matrices.

        For each action the discount is applied to the undiscounted rate of
        the backlog, conversion probabilities are predicted with the current
        logistic model and an optimistic estimate ``min(p + bonus, 1)`` is
        formed. Volume and both costs are computed from that estimate.
        Rows are then summed per ``index_context_approx``; each resulting
        matrix has one row per aggregated context and one column per entry
        of ``list_actions`` (same order). Action ``-1`` contributes zeros.
        """
        params = self.dict_params
        rate_col = params["var_rate"]
        save_col = f"{params['var_rate']}_SAVE"
        feature_cols = params["var_model_onehot"] + params["var_base_reward_costs"]

        # Keep the undiscounted rate so every action starts from the same base.
        self.backlog_context_dummy[save_col] = self.backlog_context_dummy[rate_col]

        for i_ in params["list_actions"]:
            if i_ == -1:
                for prefix in ("conversion", "volume", "discount_amount", "discount_sum", "UCB"):
                    self.backlog_context[f"{prefix}_{i_}"] = 0
                continue

            discount_ = i_ / 100
            self.backlog_context_dummy[rate_col] = self.backlog_context_dummy[save_col] * (1 - discount_)
            context_ = self.backlog_context_dummy[feature_cols].values

            conv_col = f"conversion_{i_}"
            ucb_col = f"UCB_{i_}"
            self.backlog_context[conv_col] = self.current_logistic.predict_proba(context_)[:, 1]
            current_bonus_UCB = self._cal_bonus_UCB(action=i_, context=context_)
            self.backlog_context[ucb_col] = np.minimum(self.backlog_context[conv_col] + current_bonus_UCB, 1)
            self.backlog_context[f"volume_{i_}"] = self.backlog_context[ucb_col] * self.backlog_context["amount_norm"]
            self.backlog_context[f"discount_amount_{i_}"] = (
                self.backlog_context[ucb_col] * discount_ * self.backlog_context["discount_base_norm"])
            self.backlog_context[f"discount_sum_{i_}"] = (
                self.backlog_context[ucb_col] * discount_) / params["norm_costs1"]

        # Restore the original rate column and drop the temporary copy.
        self.backlog_context_dummy[rate_col] = self.backlog_context_dummy[save_col]
        self.backlog_context_dummy = self.backlog_context_dummy.drop(save_col, axis=1)

        prefixes = ("volume_", "discount_amount_", "discount_sum_", "conversion_", "UCB_")
        var_target = [x for x in self.backlog_context.columns if any(p in x for p in prefixes)]

        self.backlog_context["count"] = 1
        self.backlog_context = (
            self.backlog_context
            .groupby("index_context_approx")[["count"] + var_target]
            .sum()
            .reset_index(drop=False)
        )

        def _matrix(prefix):
            # One column per action, in list_actions order.
            return np.array(self.backlog_context[[f"{prefix}_{a}" for a in params["list_actions"]]])

        self.current_conversion_matrix = _matrix("conversion")
        self.current_UCB_matrix = _matrix("UCB")
        self.current_target_matrix = _matrix("volume")
        self.current_constraint2_matrix = _matrix("discount_amount")
        self.current_constraint1_matrix = _matrix("discount_sum")

    def _solve_optimization_policy(self):
        """Solve the LP for a randomized policy and store it in ``self.current_policy``.

        Maximises expected volume subject to both expected costs staying within
        ``current_budget``, with each row of the action matrix a probability
        distribution over ``list_actions``.

        :raises RuntimeError: if the solver returns no solution.
        """
        self.action_matrix = cvx.Variable(shape=self.current_target_matrix.shape)

        objective = cvx.Maximize(cvx.sum(cvx.multiply(self.action_matrix, self.current_target_matrix)))
        constraints = [
            self.action_matrix >= 0,
            cvx.sum(self.action_matrix, axis=1) == 1,
            cvx.sum(cvx.multiply(self.action_matrix, self.current_constraint2_matrix)) <= self.current_budget,
            cvx.sum(cvx.multiply(self.action_matrix, self.current_constraint1_matrix)) <= self.current_budget,
        ]
        prob = cvx.Problem(objective, constraints)

        try:
            prob.solve(verbose=False)
        except cvx.error.SolverError:
            # The default solver failed; retry with SCS.
            prob.solve(verbose=False, solver="SCS")

        if self.action_matrix.value is None:
            raise RuntimeError(f"Policy optimization returned no solution (status: {prob.status})")

        self.current_policy = pd.DataFrame(self.action_matrix.value, columns=list(self.dict_params["list_actions"]))
        self.current_policy["index_context_approx"] = self.backlog_context["index_context_approx"]

    def _run_one_iteration(self):
        """Run one simulation step: choose an action, simulate it, update the model state."""
        self._get_context()
        self._cal_budget()

        verbose_step = self.dict_params["verbose"] is True and self.current_iter % 50 == 0
        if verbose_step:
            print(f"Iteration {self.current_iter}")

        break_constraint = self._check_break_constraints()
        still_exploring = self.current_iter <= (self.dict_params["n_random_action"] - 1)

        if break_constraint:
            self._take_action(policy=False, anull=True)
        elif still_exploring or getattr(self, "current_logistic", None) is None:
            # Random exploration; also used while no logistic model could be fit yet.
            self._take_action(policy=False, anull=False)
        else:
            self._calculate_gain_costs_matrix()
            self._solve_optimization_policy()
            self._take_action(policy=True, anull=False)

        if verbose_step:
            print(f"current action: {self.current_action}")

        self._update_action_context()
        self._get_reward_costs()

        if not break_constraint:
            self._update_vt()
            if self.current_iter >= (self.dict_params["n_random_action"] - 1):
                self._fit_logistic()
                self._cal_constant_UCB()

        self.current_iter += 1

    def run_simulation(self):
        """Run ``T`` iterations, optionally printing cumulative totals every 50 steps."""
        for _ in range(self.dict_params["T"]):
            self._run_one_iteration()
            if self.dict_params["verbose"] is True and self.current_iter % 50 == 0:
                print(f"Cumulative Rewards: {np.round(np.sum(self.rewards), 4)}")
                print(f"Cumulative Costs2: {np.round(np.sum(self.costs2), 4)}")
                print(f"Cumulative Costs1: {np.round(np.sum(self.costs1), 4)}")
                print("------------------------------------")